{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Out Paint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers.convolutional import Conv2D, AtrousConvolution2D\n",
    "from keras.layers import Activation, Dense, Input, Conv2DTranspose, Dense, Flatten\n",
    "from keras.layers import ReLU, Dropout, Concatenate, BatchNormalization, Reshape\n",
    "from keras.layers.advanced_activations import LeakyReLU\n",
    "from keras.models import Model, model_from_json\n",
    "from keras.optimizers import Adam\n",
    "from keras.layers.convolutional import UpSampling2D\n",
    "import os\n",
    "import numpy as np\n",
    "import PIL\n",
    "import IPython.display\n",
    "from keras_contrib.layers.normalization import InstanceNormalization\n",
    "from datetime import datetime\n",
    "from dataloader import Data\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the project-local dataloader (dataloader.Data); later cells pull image\n",
    "# batches from it with data.get_data(n).\n",
    "data = Data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum number of minutes between checkpoint/preview saves during training\n",
    "TIME_INTERVALS = 1\n",
    "# Print the Keras summary of each model right after it is built\n",
    "SHOW_SUMMARY = True\n",
    "\n",
    "# Full image shape; the discriminator and generator input shapes are derived from it\n",
    "INPUT_SHAPE = (256, 256, 3)\n",
    "# Upper bound of the epoch loop in train(), and number of images requested per training step\n",
    "EPOCHS = 500\n",
    "BATCH = 1\n",
    "\n",
    "# Fraction of the image width cut away from EACH side by mask_width()\n",
    "# (0.25 of a 256-pixel-wide image = 64 columns per side)\n",
    "MASK_PERCENTAGE = .25\n",
    "\n",
    "\n",
    "# Directories for model checkpoints and for preview images written during training\n",
    "CHECKPOINT = \"checkpoint/\"\n",
    "SAVED_IMAGES = \"saved_images/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The discriminator is trained on the two removed side strips joined together, so its\n",
    "# input width is 2 * MASK_PERCENTAGE of the full width (128 columns for the values above).\n",
    "d_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n",
    "# Default dropout rate used by d_build_conv\n",
    "d_dropout = 0.25\n",
    "# Adam with learning rate 1e-4 and beta_1 = 0.5 (Keras positional arguments)\n",
    "DCRM_OPTIMIZER = Adam(0.0001, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=d_dropout, norm=True):\n",
    "    \"\"\"Discriminator block: Conv2D -> LeakyReLU -> Dropout -> InstanceNormalization.\n",
    "\n",
    "    Args:\n",
    "        layer_input: tensor the block is applied to.\n",
    "        filter_size: number of convolution filters.\n",
    "        kernel_size: convolution kernel size.\n",
    "        strides: convolution stride.\n",
    "        activation: 'leakyrelu' adds LeakyReLU(alpha=0.2); any other value adds no activation.\n",
    "        dropout_rate: dropout rate; a falsy value skips the Dropout layer.\n",
    "        norm: True or 'inst' adds InstanceNormalization; any other value skips it.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the block.\n",
    "    \"\"\"\n",
    "    c = Conv2D(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n",
    "    if activation == 'leakyrelu':\n",
    "        c = LeakyReLU(alpha=0.2)(c)\n",
    "    if dropout_rate:\n",
    "        c = Dropout(dropout_rate)(c)\n",
    "    # Accept both the default True and the 'inst' string used by the generator blocks;\n",
    "    # comparing against 'inst' alone silently skips normalization for the default value.\n",
    "    if norm is True or norm == 'inst':\n",
    "        c = InstanceNormalization()(c)\n",
    "    return c\n",
    "\n",
    "\n",
    "def build_discriminator():\n",
    "    \"\"\"Assemble the discriminator model.\n",
    "\n",
    "    A 32-filter block (called with norm=False) is followed by four 64-filter blocks\n",
    "    that use the d_build_conv defaults; all use 5x5 kernels with stride 2. The result\n",
    "    is flattened and passed through Dense(512, relu) and a single sigmoid unit.\n",
    "\n",
    "    Returns:\n",
    "        A Keras Model mapping a d_input_shape tensor to one score per sample.\n",
    "    \"\"\"\n",
    "    n_blocks, n_filters = 4, 64\n",
    "    d_input = Input(shape=d_input_shape)\n",
    "    features = d_build_conv(d_input, 32, 5, strides=2, norm=False)\n",
    "    for _ in range(n_blocks):\n",
    "        features = d_build_conv(features, n_filters, 5, strides=2)\n",
    "    flattened = Flatten()(features)\n",
    "    hidden = Dense(512, activation='relu')(flattened)\n",
    "    score = Dense(1, activation='sigmoid')(hidden)\n",
    "    return Model(d_input, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the discriminator and compile it with a mean-squared-error loss\n",
    "DCRM = build_discriminator()\n",
    "DCRM.compile(loss='mse', optimizer=DCRM_OPTIMIZER)\n",
    "if SHOW_SUMMARY:\n",
    "    DCRM.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generator Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The generator is fed the centre strip kept by mask_width(), whose width is\n",
    "# width * (1 - 2 * MASK_PERCENTAGE).\n",
    "# NOTE(review): the expression below is width * 2 * MASK_PERCENTAGE, which matches that\n",
    "# only because MASK_PERCENTAGE is 0.25 - confirm before changing the mask size.\n",
    "g_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n",
    "# Default dropout rate used by g_build_conv\n",
    "g_dropout = 0.25\n",
    "# Adam with learning rate 1e-3 and beta_1 = 0.5 (Keras positional arguments)\n",
    "GEN_OPTIMIZER = Adam(0.001, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def g_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=g_dropout, norm='inst', dilation=1):\n",
    "    \"\"\"Generator block: (dilated) convolution followed by optional ReLU, Dropout and InstanceNormalization.\n",
    "\n",
    "    Args:\n",
    "        layer_input: tensor the block is applied to.\n",
    "        filter_size: number of convolution filters.\n",
    "        kernel_size: convolution kernel size.\n",
    "        strides: convolution stride.\n",
    "        activation: 'leakyrelu' adds a plain ReLU layer (despite the name); any other\n",
    "            value adds no activation.\n",
    "        dropout_rate: dropout rate; a falsy value skips the Dropout layer.\n",
    "        norm: 'inst' adds InstanceNormalization; any other value skips it.\n",
    "        dilation: dilation rate applied to both spatial axes.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the block.\n",
    "    \"\"\"\n",
    "    conv = AtrousConvolution2D(filter_size, kernel_size=kernel_size, strides=strides,atrous_rate=(dilation,dilation), padding='same')\n",
    "    out = conv(layer_input)\n",
    "\n",
    "    # Collect the optional layers in application order, then apply them one after another\n",
    "    follow_ups = []\n",
    "    if activation == 'leakyrelu':\n",
    "        follow_ups.append(ReLU())\n",
    "    if dropout_rate:\n",
    "        follow_ups.append(Dropout(dropout_rate))\n",
    "    if norm == 'inst':\n",
    "        follow_ups.append(InstanceNormalization())\n",
    "    for layer in follow_ups:\n",
    "        out = layer(out)\n",
    "    return out\n",
    "\n",
    "\n",
    "def g_build_deconv(layer_input, filter_size, kernel_size=3, strides=2, activation='relu', dropout=0):\n",
    "    \"\"\"Generator up-sampling block: Conv2DTranspose followed by an optional ReLU.\n",
    "\n",
    "    Args:\n",
    "        layer_input: tensor the block is applied to.\n",
    "        filter_size: number of transposed-convolution filters.\n",
    "        kernel_size: kernel size.\n",
    "        strides: stride of the transposed convolution.\n",
    "        activation: 'relu' adds a ReLU layer; any other value adds no activation.\n",
    "        dropout: accepted but not used by this block.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the block.\n",
    "    \"\"\"\n",
    "    upsampled = Conv2DTranspose(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n",
    "    if activation != 'relu':\n",
    "        return upsampled\n",
    "    return ReLU()(upsampled)\n",
    "\n",
    "\n",
    "def build_generator():\n",
    "    \"\"\"Assemble the generator model.\n",
    "\n",
    "    Layout: one stride-1 block and two stride-2 blocks, a stride-1 block, four\n",
    "    dilated blocks (rates 2, 4, 8, 16), one more stride-1 block, two stride-2\n",
    "    transposed convolutions, a final 64-filter block and a 3-channel tanh output.\n",
    "\n",
    "    Returns:\n",
    "        A Keras Model mapping a g_input_shape tensor to a 3-channel tensor in [-1, 1].\n",
    "    \"\"\"\n",
    "    g_input = Input(shape=g_input_shape)\n",
    "\n",
    "    # Down-sampling path\n",
    "    g1 = g_build_conv(g_input, 64, 5, strides=1)\n",
    "    g2 = g_build_conv(g1, 128, 4, strides=2)\n",
    "    g3 = g_build_conv(g2, 128, 4, strides=2)\n",
    "\n",
    "    # Dilated blocks at constant resolution\n",
    "    g4 = g_build_conv(g3, 256, 4, strides=1)\n",
    "    g5 = g_build_conv(g4, 512, 3, strides=1, dilation=2)\n",
    "    g6 = g_build_conv(g5, 512, 3, strides=1, dilation=4)\n",
    "    g7 = g_build_conv(g6, 512, 3, strides=1, dilation=8)\n",
    "    g8 = g_build_conv(g7, 512, 3, strides=1, dilation=16)\n",
    "\n",
    "    # Up-sampling path\n",
    "    g9 = g_build_conv(g8, 256, 4, strides=1)\n",
    "    g10 = g_build_deconv(g9, 128, 4, strides=2)\n",
    "    g11 = g_build_deconv(g10, 128, 4, strides=2)\n",
    "    g12 = g_build_conv(g11, 64, 4, strides=1)\n",
    "\n",
    "    # The output layer must consume g12; attaching it to g11 leaves the g12 block\n",
    "    # outside the graph between g_input and g_output, so it would never be trained.\n",
    "    g_output = AtrousConvolution2D(3, kernel_size=4, strides=(1,1), activation='tanh',padding='same', atrous_rate=(1,1))(g12)\n",
    "\n",
    "    return Model(g_input, g_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the generator and compile it with a mean-squared-error loss\n",
    "GEN = build_generator()\n",
    "GEN.compile(loss='mse', optimizer=GEN_OPTIMIZER)\n",
    "if SHOW_SUMMARY:\n",
    "    GEN.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combined Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined model used to train the generator: IMAGE -> GEN -> DCRM.\n",
    "IMAGE = Input(shape=g_input_shape)\n",
    "# NOTE(review): trainable is switched off before COMBINED is compiled - presumably so that\n",
    "# only the generator weights are updated through COMBINED; verify for the installed Keras version.\n",
    "DCRM.trainable = False\n",
    "GENERATED_IMAGE = GEN(IMAGE)\n",
    "CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n",
    "\n",
    "# Two outputs: the discriminator score of the generated strips and the generated strips\n",
    "# themselves, each with its own mean-squared-error loss.\n",
    "COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n",
    "COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Masking and De-Masking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_width(img, mask_percentage=None):\n",
    "    \"\"\"Split an image into the kept centre and the two removed side strips.\n",
    "\n",
    "    Args:\n",
    "        img: numpy array whose first two axes are (height, width).\n",
    "        mask_percentage: fraction of the width removed from EACH side; when None the\n",
    "            module-level MASK_PERCENTAGE is used.\n",
    "\n",
    "    Returns:\n",
    "        (image, missing_part): `image` is a copy of the centre columns, `missing_part`\n",
    "        is the left strip followed by the right strip, joined along the width axis.\n",
    "    \"\"\"\n",
    "    if mask_percentage is None:\n",
    "        mask_percentage = MASK_PERCENTAGE\n",
    "    width = img.shape[1]\n",
    "    new_width = int(width * mask_percentage)\n",
    "    left_strip = img[:, :new_width]\n",
    "    right_strip = img[:, width - new_width:]\n",
    "    missing_part = np.concatenate((left_strip, right_strip), axis=1)\n",
    "    # Copy first so the returned centre never aliases the caller's array\n",
    "    image = img.copy()[:, new_width:width - new_width]\n",
    "    return image, missing_part\n",
    "\n",
    "\n",
    "def get_masked_images(images):\n",
    "    \"\"\"Apply mask_width to every image of a batch.\n",
    "\n",
    "    Args:\n",
    "        images: iterable of images accepted by mask_width.\n",
    "\n",
    "    Returns:\n",
    "        Two numpy arrays: the kept centres and the matching removed side strips.\n",
    "    \"\"\"\n",
    "    splits = [mask_width(picture) for picture in images]\n",
    "    centres = [centre for centre, _ in splits]\n",
    "    strips = [strip for _, strip in splits]\n",
    "    return np.array(centres), np.array(strips)\n",
    "\n",
    "\n",
    "def get_demask_images(original_images, generated_images):\n",
    "    \"\"\"Wrap generated side strips around their centre images.\n",
    "\n",
    "    Each generated image is cut in half along the width axis; the first half is put\n",
    "    to the left of the matching original image and the second half to its right.\n",
    "\n",
    "    Args:\n",
    "        original_images: iterable of centre images.\n",
    "        generated_images: iterable of strip images, paired with original_images by position.\n",
    "\n",
    "    Returns:\n",
    "        A numpy array holding the stitched images.\n",
    "    \"\"\"\n",
    "    def _stitch(centre, strips):\n",
    "        half = strips.shape[1] // 2\n",
    "        return np.concatenate((strips[:, :half], centre, strips[:, half:]), axis=1)\n",
    "\n",
    "    stitched = [_stitch(centre, strips) for centre, strips in zip(original_images, generated_images)]\n",
    "    return np.asarray(stitched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstration of masking and de-masking on a single image.\n",
    "# Note: IPython display gives false colors.\n",
    "x = data.get_data(1)\n",
    "\n",
    "# a (kept centre strips) is the model input, b (removed side strips) is the model target\n",
    "a, b = get_masked_images(x)\n",
    "# 10-pixel-wide separator with the same height as the image\n",
    "border = np.ones([x[0].shape[0], 10, 3]).astype(np.uint8)\n",
    "print('After masking')\n",
    "print('\\tOriginal Image\\t\\t\\t a \\t\\t b')\n",
    "image = np.concatenate((border, x[0],border,a[0],border, b[0], border), axis=1)\n",
    "IPython.display.display(PIL.Image.fromarray(image))\n",
    "\n",
    "# Stitching the first half of b, then a, then the second half of b back together\n",
    "print(\"After desmasking: 'b/2' + a + 'b/2' \")\n",
    "c = get_demask_images(a,b)\n",
    "IPython.display.display(PIL.Image.fromarray(c[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utilities\n",
    "1. Save Model\n",
    "2. Load Model\n",
    "3. Save Image\n",
    "4. Save Log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_model():\n",
    "    \"\"\"Write the architecture (JSON) and weights (HDF5) of DCRM and GEN to CHECKPOINT.\n",
    "\n",
    "    Files are named <model>.json and <model>.hdf5. The checkpoint directory is\n",
    "    created when it does not exist yet.\n",
    "    \"\"\"\n",
    "    # Without this the first save of a fresh run fails on the missing directory\n",
    "    os.makedirs(CHECKPOINT, exist_ok=True)\n",
    "\n",
    "    for model, model_name in zip([DCRM, GEN], ['DCRM', 'GEN']):\n",
    "        model_path = os.path.join(CHECKPOINT, \"%s.json\" % model_name)\n",
    "        weights_path = os.path.join(CHECKPOINT, \"%s.hdf5\" % model_name)\n",
    "        # Context manager so the architecture file is flushed and closed immediately\n",
    "        with open(model_path, 'w') as f:\n",
    "            f.write(model.to_json())\n",
    "        model.save_weights(weights_path)\n",
    "    print(\"Saved Model\")\n",
    "    \n",
    "    \n",
    "def load_model():\n",
    "    \"\"\"Restore DCRM and GEN from CHECKPOINT and rebuild the combined model.\n",
    "\n",
    "    Nothing is changed unless the .json architecture and the .hdf5 weights of BOTH\n",
    "    models are present; a missing checkpoint directory is treated like missing files.\n",
    "    On success the globals DCRM, GEN, IMAGE, GENERATED_IMAGE, CONF_GENERATED_IMAGE\n",
    "    and COMBINED are replaced.\n",
    "    \"\"\"\n",
    "    global DCRM, GEN, COMBINED, IMAGE, GENERATED_IMAGE, CONF_GENERATED_IMAGE\n",
    "\n",
    "    # Checking if all the model files exist; os.listdir raises on a missing directory,\n",
    "    # which is the normal state before the first save.\n",
    "    model_names = ['DCRM', 'GEN']\n",
    "    files = os.listdir(CHECKPOINT) if os.path.isdir(CHECKPOINT) else []\n",
    "    for model_name in model_names:\n",
    "        if model_name + \".json\" not in files or model_name + \".hdf5\" not in files:\n",
    "            print(\"Models not Found\")\n",
    "            return\n",
    "\n",
    "    def _restore(model_name):\n",
    "        # Rebuild one model from its JSON architecture, then load its weights\n",
    "        with open(os.path.join(CHECKPOINT, \"%s.json\" % model_name), 'r') as f:\n",
    "            model = model_from_json(f.read())\n",
    "        model.load_weights(os.path.join(CHECKPOINT, \"%s.hdf5\" % model_name))\n",
    "        return model\n",
    "\n",
    "    # The discriminator is compiled before it is frozen for the combined model below\n",
    "    DCRM = _restore('DCRM')\n",
    "    DCRM.compile(loss='mse', optimizer=DCRM_OPTIMIZER)\n",
    "\n",
    "    GEN = _restore('GEN')\n",
    "\n",
    "    # Combined Model\n",
    "    DCRM.trainable = False\n",
    "    IMAGE = Input(shape=g_input_shape)\n",
    "    GENERATED_IMAGE = GEN(IMAGE)\n",
    "    CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n",
    "\n",
    "    COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n",
    "    COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)\n",
    "\n",
    "    print(\"loaded model\")\n",
    "    \n",
    "    \n",
    "def save_image(epoch, steps):\n",
    "    \"\"\"Render, save and display a preview: original | generator input | out-painted result.\n",
    "\n",
    "    One image is taken from `data`, its centre strip is fed to GEN and the generated\n",
    "    side strips are stitched back around that centre. The three pictures, separated\n",
    "    by 10-pixel borders, are written to SAVED_IMAGES as <epoch>_<steps>.jpg.\n",
    "\n",
    "    Args:\n",
    "        epoch: current epoch, used in the file name.\n",
    "        steps: current step, used in the file name.\n",
    "    \"\"\"\n",
    "    original = data.get_data(1)\n",
    "    # A None result is retried once (presumably the loader returns None at the end of a pass)\n",
    "    if original is None:\n",
    "        original = data.get_data(1)\n",
    "\n",
    "    mask_image_original, _ = get_masked_images(original)\n",
    "    # x / 127.5 - 1 maps 0..255 to -1..1; (y + 1) * 127.5 maps the tanh output back\n",
    "    gen_missing = GEN.predict(mask_image_original / 127.5 - 1)\n",
    "    gen_missing = ((gen_missing + 1) * 127.5).astype(np.uint8)\n",
    "    demask_image = get_demask_images(mask_image_original, gen_missing)\n",
    "\n",
    "    border = np.ones([original[0].shape[0], 10, 3]).astype(np.uint8)\n",
    "    final_image = np.concatenate((border, original[0], border, mask_image_original[0], border, demask_image[0], border), axis=1)\n",
    "\n",
    "    # cv2.imwrite does not create directories and only returns False on failure,\n",
    "    # so make sure the target directory exists before writing.\n",
    "    os.makedirs(SAVED_IMAGES, exist_ok=True)\n",
    "    file_name = str(epoch) + \"_\" + str(steps) + \".jpg\"\n",
    "    cv2.imwrite(os.path.join(SAVED_IMAGES, file_name), final_image)\n",
    "    print(\"\\t1.Original image \\t 2.Input \\t\\t 3. Output\")\n",
    "    IPython.display.display(PIL.Image.fromarray(final_image))\n",
    "    print(\"image saved\")\n",
    "\n",
    "\n",
    "def save_log(log):\n",
    "    \"\"\"Append `log` as one line to log.txt in the current working directory.\"\"\"\n",
    "    with open('log.txt', 'a') as log_file:\n",
    "        entry = \"%s\\n\" % log\n",
    "        log_file.write(entry)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run the adversarial training loop for EPOCHS epochs.\n",
    "\n",
    "    Every step pulls BATCH images from `data`, trains the discriminator on the real\n",
    "    and on the generated side strips, then trains the generator twice through\n",
    "    COMBINED. Each step is printed and appended to the log file; the models and a\n",
    "    preview image are saved once at least TIME_INTERVALS minutes have passed since\n",
    "    the previous save.\n",
    "    \"\"\"\n",
    "    saved_time = datetime.now()\n",
    "\n",
    "    # range() excludes its stop value, so + 1 is needed to really run EPOCHS epochs\n",
    "    for epoch in range(1, EPOCHS + 1):\n",
    "        steps = 1\n",
    "        while True:\n",
    "            original = data.get_data(BATCH)\n",
    "            # A None batch ends the current epoch\n",
    "            if original is None:\n",
    "                break\n",
    "            batch_size = original.shape[0]\n",
    "\n",
    "            # x / 127.5 - 1 maps 0..255 to -1..1, the range of the generator's tanh output\n",
    "            mask_image, missing_image = get_masked_images(original)\n",
    "            mask_image = mask_image / 127.5 - 1\n",
    "            missing_image = missing_image / 127.5 - 1\n",
    "\n",
    "            # Train Discriminator: real strips are labelled 1, generated strips 0\n",
    "            gen_missing = GEN.predict(mask_image)\n",
    "\n",
    "            real = np.ones([batch_size, 1])\n",
    "            fake = np.zeros([batch_size, 1])\n",
    "\n",
    "            d_loss_original = DCRM.train_on_batch(missing_image, real)\n",
    "            d_loss_mask = DCRM.train_on_batch(gen_missing, fake)\n",
    "            d_loss = 0.5 * np.add(d_loss_original, d_loss_mask)\n",
    "\n",
    "            # Train Generator: two updates per discriminator update\n",
    "            for _ in range(2):\n",
    "                g_loss = COMBINED.train_on_batch(mask_image, [real, missing_image])\n",
    "\n",
    "            log = \"epoch: %d, steps: %d, DIS loss: %s, GEN loss: %s, Identity loss: %s\" % (\n",
    "                epoch, steps, str(d_loss), str(g_loss[0]), str(g_loss[2]))\n",
    "            print(log)\n",
    "            save_log(log)\n",
    "            steps += 1\n",
    "\n",
    "            # Save model if time taken >= TIME_INTERVALS minutes; total_seconds() is used\n",
    "            # because timedelta.seconds drops the days component of the difference\n",
    "            current_time = datetime.now()\n",
    "            difference_time = current_time - saved_time\n",
    "            if difference_time.total_seconds() >= (TIME_INTERVALS * 60):\n",
    "                save_model()\n",
    "                save_image(epoch, steps)\n",
    "                saved_time = current_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume from the files in CHECKPOINT if a previous run saved both models\n",
    "load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training; every step is printed and appended to log.txt\n",
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recursive paint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved discriminator and generator before out-painting\n",
    "load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recursive_paint(image, factor=3):\n",
    "    \"\"\"Out-paint an image repeatedly, widening it on every pass.\n",
    "\n",
    "    The first pass masks `image`, lets GEN predict the side strips from the kept\n",
    "    centre and stitches them around that centre. Every further pass feeds the\n",
    "    previously generated strips back into GEN and stitches the new strips around\n",
    "    the running result.\n",
    "\n",
    "    Args:\n",
    "        image: a single image accepted by get_masked_images.\n",
    "        factor: number of passes.\n",
    "\n",
    "    Returns:\n",
    "        The stitched image after the last pass, or None when no pass is executed.\n",
    "    \"\"\"\n",
    "    final_image = None\n",
    "    gen_missing = None\n",
    "    for step in range(factor):\n",
    "        if step:\n",
    "            # Later passes: the last prediction becomes the next generator input\n",
    "            gen_missing = GEN.predict(gen_missing)\n",
    "            canvas = [final_image]\n",
    "        else:\n",
    "            # First pass: start from the centre strip of the given image\n",
    "            canvas, _ = get_masked_images([image])\n",
    "            gen_missing = GEN.predict(canvas)\n",
    "        final_image = get_demask_images(canvas, gen_missing)[0]\n",
    "    return final_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Out-paint one image from the dataloader recursively and display the result\n",
    "images = data.get_data(1)\n",
    "\n",
    "for i, image in enumerate(images):\n",
    "    # x / 127.5 - 1 maps 0..255 to -1..1 before the image is handed to the generator\n",
    "    image = image / 127.5 - 1\n",
    "    image = recursive_paint(image)\n",
    "    # Undo the scaling and convert to 8-bit values for display\n",
    "    image = (image + 1) * 127.5\n",
    "    image = image.astype(np.uint8)\n",
    "    # NOTE(review): `path` is built but never used, so nothing is written to disk - confirm intent\n",
    "    path = 'recursive/'+str(i)+'.jpg'\n",
    "    IPython.display.display(PIL.Image.fromarray(image))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
